Project by: Anna Wang, Jimmy Yang and Philip Chen
This project explores the use of Generative Adversarial Networks (GANs) to remove objects and inpaint specific regions of images. Inpainting is a technique for seamlessly inserting or removing content within an image, commonly used in image editing and restoration. Given a masked area, we wish to fill it with new content that blends realistically with the rest of the image.
Object removal and image inpainting are inherently challenging tasks in computer vision. To tackle this project, we employ a special variant of GAN, the Deep Convolutional Generative Adversarial Network (DCGAN). Using convolutional layers suited to image processing, we architected a generator that produces convincing image completions, while the discriminator learns to distinguish authentic images from the generator's output. As the two models compete in an adversarial process, both improve their respective abilities, yielding a generator better trained to inpaint masked images.
To evaluate performance, we used a curated dataset of cat images, with a square hole of fixed size masked out of the center of every image. This choice of dataset tests the model's ability to handle fine textures like fur and to preserve anatomical consistency. The generator performs reasonably well at reconstructing altered parts of the cat's body, but struggles with facial features.
Ultimately, this project highlighted the complexity of the inpainting task. We conclude with several observations and potential directions for improving the model's performance in future work.
Anna Wang ( , @uwaterloo.ca):
Jimmy Yang (20890430, jj7yang@uwaterloo.ca): I helped curate the dataset of cats to be used in model training. I worked with Anna to split our original dataset into images of cats and defined a custom CatDataset class to help us insert square masks for training purposes. I also explored various model training methods against our dataset to help decide what image resolution to use and how to architect our models. Then, together with the group I contributed to fine-tuning our models, analyzing our results, and iterating on improvements.
Philip Chen ( , @uwaterloo.ca):
The following external Python libraries were essential to the development and implementation of this project:
NumPy: Used for efficient numerical operations and array manipulation throughout the project, particularly when handling image masks, pixel data, and preprocessing pipelines.
Matplotlib: Used for visualizing training progress, loss curves, and before-and-after comparisons of image inpainting results. It plays a key role in debugging and showcasing the model's performance.
PyTorch: Foundational deep learning framework for this project. It enabled the construction and training of the GAN architecture through its neural network modules and optimization tools.
# Standard libraries for math and plotting
import numpy as np
import matplotlib.pyplot as plt
# Standard PyTorch for network training
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms import v2
import torchvision.utils as vutils
# Other utilities and misc
import os
from copy import deepcopy
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
Our dataset comes from Hugging Face Dataset cats_vs_dogs: https://huggingface.co/datasets/microsoft/cats_vs_dogs. From the dataset, we filter for images of cats and split up training, validation, and testing sets.
We install Hugging Face's datasets library to load the dataset for training.
# uncomment cli login if need to login
# !huggingface-cli login
!pip install datasets
from datasets import load_dataset
ds = load_dataset("microsoft/cats_vs_dogs")
print(ds['train'])
# Filter the dataset into dog and cat subsets by label
data = ds['train']
dogs = data.filter(lambda example: example['labels'] == 1)
cats = data.filter(lambda example: example['labels'] == 0)
print("Total dataset size:", len(data))
print("Dogs:", len(dogs))
print("Cats:", len(cats))
# Split the dataset
train_test_split = cats.train_test_split(test_size=0.2, seed=42)
test_val_split = train_test_split['test'].train_test_split(test_size=0.5, seed=42)
train_dataset = train_test_split['train']['image']
val_dataset = test_val_split['train']['image']
test_dataset = test_val_split['test']['image']
print(f"Training dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")
train_dataset[0]
In order to create training samples for our model, we define the function create_masks, which cuts a square hole of side hole_size out of the center of an image. We acknowledge that designing a network model for inpainting has a number of challenges. To simplify the project, we assume that the holes are always square, always centered, and always the same size. The central region is surrounded by image context on all four sides: top, bottom, left, and right. This gives the model rich, balanced context to infer what the missing content should look like. During training, consistently masking the center allows the model to focus its capacity on learning to inpaint a specific region. This makes convergence faster and often leads to better initial performance compared to random or irregular masking. A future improvement would be to allow custom, user-created masks.
def create_masks(num_masks):
hole_size = 20
im_size = 64
x = int((im_size - hole_size) / 2.0)
y = int((im_size - hole_size) / 2.0)
mask = torch.zeros((1, im_size, im_size))
mask[0, y : y + hole_size, x : x + hole_size] = 1
masks = mask.repeat_interleave(num_masks, dim=0)
return masks.unsqueeze(1)
plt.imshow(create_masks(1)[0][0], cmap='gray')
plt.show()
The following class defines our CatDataset. Each item is a (masked image, ground truth image) pair.
class CatDataset(Dataset):
def __init__(self, all_imgs, transforms=None) -> None:
super().__init__()
self.transforms = transforms
self.all_imgs = all_imgs
def __len__(self):
return len(self.all_imgs)
def __getitem__(self, index):
image = self.all_imgs[index]
# 1. Transform the image as needed
transformed_img = self.transforms(image)
if transformed_img.shape[0] == 1:
transformed_img = transformed_img.repeat(3, 1, 1)
# 2. Create copy of image before masking
ground_truth_image = deepcopy(transformed_img)
        # 3. Create the center mask and zero out the hole in the image
mask = create_masks(1)[0]
transformed_img = (1 - mask) * transformed_img
return transformed_img, ground_truth_image
While composing our dataset, we experimented with several transformations, including what image resolution to train on, various crops, and flips. We landed on a resolution of 64x64 because it retains just enough detail to make out facial features (i.e. eyes, nose, mouth), whereas smaller resolutions like 32x32 make the images almost unrecognizable as cats.
transforms = v2.Compose([
v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
v2.Resize(size=(64, 64))
])
transformed_train_dataset = CatDataset(train_dataset, transforms=transforms)
transformed_val_dataset = CatDataset(val_dataset, transforms=transforms)
transformed_test_dataset = CatDataset(test_dataset, transforms=transforms)
# Grab pictures of just cats, some other pictures have their owners in the background
image0, image1, image2 = transformed_train_dataset[1], transformed_train_dataset[3], transformed_train_dataset[10]
# Display training images and their ground truth images
f, ax = plt.subplots(3, 2, figsize=(12, 12))
ax[0, 0].imshow(np.transpose(image0[0], (1, 2, 0)))
ax[0, 1].imshow(np.transpose(image0[1], (1, 2, 0)))
ax[1, 0].imshow(np.transpose(image1[0], (1, 2, 0)))
ax[1, 1].imshow(np.transpose(image1[1], (1, 2, 0)))
ax[2, 0].imshow(np.transpose(image2[0], (1, 2, 0)))
ax[2, 1].imshow(np.transpose(image2[1], (1, 2, 0)))
ax[0, 0].set_title("Training Images");
ax[0, 1].set_title("Ground Truth Images");
train_dataset, val_dataset, test_dataset = transformed_train_dataset, transformed_val_dataset, transformed_test_dataset
BATCH_SIZE = 8
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
The special type of GAN we plan to use is a DCGAN, a variant in which the generator and the discriminator are both convolutional neural networks trained in opposition. In image processing tasks like inpainting, super-resolution, or style transfer, DCGANs learn to generate realistic images by mimicking the distribution of real images.
The generator aims to produce realistic images that can fool the discriminator. For image processing tasks like inpainting, the generator is conditioned on a masked input image whose center region has been zeroed out.
In the end, the goal is to optimize the adversarial loss of our model, which from lec12 we know to be:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$
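As a minimal, self-contained sketch (with toy discriminator outputs, not our trained models), this adversarial objective reduces to binary cross-entropy against real/fake labels:

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Toy discriminator outputs in (0, 1): fairly confident on real, unsure on fake.
d_real = torch.tensor([0.9, 0.8])   # D(x) on real images
d_fake = torch.tensor([0.3, 0.4])   # D(G(z)) on generated images

# The discriminator maximizes log D(x) + log(1 - D(G(z))): equivalently,
# it minimizes BCE against labels 1 (real) and 0 (fake).
d_loss = criterion(d_real, torch.ones(2)) + criterion(d_fake, torch.zeros(2))

# The generator minimizes log(1 - D(G(z))); in practice (and in our train
# loop below) it instead maximizes log D(G(z)), i.e. BCE of the fake
# outputs against the "real" label 1.
g_loss = criterion(d_fake, torch.ones(2))
print(d_loss.item(), g_loss.item())
```

This mirrors how our train function computes `d_err` and `g_err`.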
We experimented with various architectures, referencing academic papers for inspiration. In particular, we took inspiration from S. Iizuka, E. Simo-Serra, and H. Ishikawa, modifying their architecture to fit our needs. One notable difference is that we chose not to adopt the paper's two discriminators (one global, one local). To simplify training, we instead use a single discriminator that functions as a simple CNN classifying between authentic and generated images.
The generator uses an encoder-decoder structure with convolutional layers to extract spatial features and deconvolutional layers to reconstruct the full image. Dilated convolutions expand the receptive field, helping the model capture broader context for coherent inpainting. Batch normalization stabilizes training and improves convergence. The discriminator, designed as a patch-based classifier, uses convolution and down-sampling to evaluate image realism, focusing on structural and textural consistency. This adversarial setup pushes the generator to create realistic, context-aware content. Together, these components enable effective learning of both local details and global patterns for high-quality image inpainting.
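The effect of dilation on feature-map size can be checked directly. The sketch below uses the same kernel size, padding, and dilation values as the generator's dilated layers; because padding=1 is smaller than the dilation, each dilated layer trims the borders, which is why the shapes in the generator docstring shrink from 16x16 to 14x14 to 8x8:

```python
import torch
import torch.nn as nn

# A 3x3 kernel with dilation=2 spans a 5x5 window (dilation=4 spans 9x9);
# with padding=1 each dilated layer loses pixels at every border.
x = torch.randn(1, 256, 16, 16)

dilated2 = nn.Conv2d(256, 256, 3, stride=1, padding=1, dilation=2)
dilated4 = nn.Conv2d(256, 256, 3, stride=1, padding=1, dilation=4)

y = dilated2(x)   # 16x16 -> 14x14
z = dilated4(y)   # 14x14 -> 8x8
print(y.shape, z.shape)
```

Setting padding equal to the dilation would preserve the spatial size; our architecture instead accepts the shrinkage and documents the resulting shapes.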
class Generator(nn.Module):
"""
This is the generator class. The generator takes in a masked image and its job is to inpaint (fill in) the
missing part of the masked image.
The architecture of the generator is inspired by the generator in the article mentioned
earlier:
1) 4x64x64 -> 64x64x64
2) 64x64x64 -> 128x32x32
3) 128x32x32 -> 256x16x16
4) 256x16x16 -> 256x16x16
5) 256x16x16 -> 256x14x14
6) 256x14x14 -> 256x8x8
7) 256x8x8 -> 256x8x8
8) 256x8x8 -> 128x16x16
9) 128x16x16 -> 128x16x16
10) 128x16x16 -> 64x32x32
11) 64x32x32 -> 64x64x64
12) 64x64x64 -> 64x64x64
13) 64x64x64 -> 3x64x64
"""
def __init__(self, im_channels):
super().__init__()
# Encoder
self.conv1 = nn.Conv2d(im_channels + 1, 64, 5, stride=1, padding=2)
self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
self.conv3 = nn.Conv2d(128, 256, 3, stride=2, padding=1)
self.conv4 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
# Dilation
self.conv5 = nn.Conv2d(256, 256, 3, stride=1, padding=1, dilation=2)
self.conv6 = nn.Conv2d(256, 256, 3, stride=1, padding=1, dilation=4)
self.conv7 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
# Decoder
self.deconv8 = nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1)
self.conv9 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
self.deconv10 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
self.deconv11 = nn.ConvTranspose2d(64, 64, 4, stride=2, padding=1)
self.conv12 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
self.conv13 = nn.Conv2d(64, 3, 3, stride=1, padding=1)
# Batch norms for each layer
self.bn1 = nn.BatchNorm2d(64)
self.bn2 = nn.BatchNorm2d(128)
self.bn3 = nn.BatchNorm2d(256)
self.bn4 = nn.BatchNorm2d(256)
self.bn5 = nn.BatchNorm2d(256)
self.bn6 = nn.BatchNorm2d(256)
self.bn7 = nn.BatchNorm2d(256)
self.bn8 = nn.BatchNorm2d(128)
self.bn9 = nn.BatchNorm2d(128)
self.bn10 = nn.BatchNorm2d(64)
self.bn11 = nn.BatchNorm2d(64)
self.bn12 = nn.BatchNorm2d(64)
# Activation Functions
self.relu = torch.nn.ReLU()
self.sig = nn.Sigmoid()
def forward(self, x):
# Encoder layers
x1 = self.bn1(self.relu(self.conv1(x)))
x2 = self.bn2(self.relu(self.conv2(x1)))
x3 = self.bn3(self.relu(self.conv3(x2)))
x4 = self.bn4(self.relu(self.conv4(x3)))
# Dilation Layers
x5 = self.bn5(self.relu(self.conv5(x4)))
x6 = self.bn6(self.relu(self.conv6(x5)))
x7 = self.bn7(self.relu(self.conv7(x6)))
# Decoder layers
x8 = self.bn8(self.relu(self.deconv8(x7)))
x9 = self.bn9(self.relu(self.conv9(x8)))
x10 = self.bn10(self.relu(self.deconv10(x9)))
x11 = self.bn11(self.relu(self.deconv11(x10)))
x12 = self.bn12(self.relu(self.conv12(x11)))
output = self.sig(self.conv13(x12))
return output
class Discriminator(nn.Module):
"""
This is the discriminator class. The discriminator takes in an image and its job is to classify whether the image is generated(fake)
or real (non-generated).
The architecture of the network is the same as a simple CNN classifier
with the following architecture:
1) 3x64x64 -> 64x32x32
2) 64x32x32 -> 128x16x16
3) 128x16x16 -> 256x8x8
4) 256x8x8 -> 512x4x4
5) 512x4x4 -> 512x1x1
6) 512 -> 512
7) 512 -> 512
8) 512 -> 256
9) 256 -> 10
10) 10 -> 1
"""
def __init__(self):
super().__init__()
# Convolutional layers
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
# Batch norm layers
self.bn1 = nn.BatchNorm2d(64)
self.bn2 = nn.BatchNorm2d(128)
self.bn3 = nn.BatchNorm2d(256)
self.bn4 = nn.BatchNorm2d(512)
# Linear layers
self.fc1 = nn.Linear(512, 512)
self.fc2 = nn.Linear(512, 512)
self.fc3 = nn.Linear(512, 256)
self.fc4 = nn.Linear(256, 10)
self.fc5 = nn.Linear(10, 1)
# Activation functions and pooling
self.down_sample = nn.AvgPool2d(2)
self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
def forward(self, x):
        # Forward pass through the layers
x = self.down_sample(self.conv1(x))
x = self.relu(x)
x = self.bn1(x)
x = self.down_sample(self.conv2(x))
x = self.relu(x)
x = self.bn2(x)
x = self.down_sample(self.conv3(x))
x = self.relu(x)
x = self.bn3(x)
x = self.down_sample(self.conv4(x))
x = self.relu(x)
x = self.bn4(x)
x = self.global_pool(x)
x = x.view(x.size(0), -1)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
x = self.relu(x)
x = self.fc4(x)
x = self.relu(x)
x = self.fc5(x)
x = self.sigmoid(x)
return x
def train(g, d, train_loader, val_loader, criterion, optimizer_g, optimizer_d, epochs, verbose=False, show_images=False):
"""
Trains model for image inpainting.
Args:
g: The generator model.
d: The discriminator model.
train_loader: DataLoader for the training dataset.
val_loader: DataLoader for the validation dataset.
criterion: The loss function.
optimizer_g: Optimizer for the generator.
optimizer_d: Optimizer for the discriminator.
epochs: The number of training epochs.
verbose: If True, prints training progress.
show_images: If True, displays generated images during validation.
"""
# Total losses for epoch x batches
g_losses = []
d_losses = []
total_steps = len(train_loader)
val_iter = iter(val_loader)
for epoch in range(epochs):
g.train()
d.train()
# Average loss per batch
g_avg_loss_per_batch = 0
d_avg_loss_per_batch = 0
for i, batch in enumerate(train_loader):
im_masked, im_truth = batch
# Move data to GPU
im_masked = im_masked.to(device)
im_truth = im_truth.to(device)
# Add the masks as a separate channel
N = batch[0].shape[0]
masks = create_masks(N).to(device)
im_masked = torch.cat((im_masked, masks), dim=1)
# --- Train Discriminator ---
d.zero_grad()
# Train discriminator with ground truth images
outputs = d(im_truth).view(-1)
true_labels = torch.ones(N).to(device)
d_err = criterion(outputs, true_labels)
# Train discriminator with fake images
fake_im = g(im_masked)
inpainted = fake_im * masks[0] + im_masked[:, :3, :, :]
output = d(inpainted.detach()).view(-1)
fake_labels = torch.zeros(N).to(device)
d_err += criterion(output, fake_labels)
d_err.backward()
optimizer_d.step()
# --- Train Generator ---
g.zero_grad()
fake_im = g(im_masked)
inpainted = fake_im * masks[0] + im_masked[:, :3, :, :]
outputs = d(inpainted).view(-1)
g_err = criterion(outputs, true_labels)
g_err.backward()
optimizer_g.step()
# Keep track of loss
g_avg_loss_per_batch += g_err.item()
d_avg_loss_per_batch += d_err.item()
g_losses.append(g_err.item())
d_losses.append(d_err.item())
if show_images and i == 0:
g.eval()
with torch.no_grad():
val_batch = next(val_iter)
val_masked, val_truth = val_batch
# Move validation data to GPU
val_masked = val_masked.to(device)
val_truth = val_truth.to(device)
N = val_masked.shape[0]
val_masks = create_masks(N).to(device)
val_input = torch.cat((val_masked, val_masks), dim=1)
# Inpainting
fake_val = g(val_input)
inpainted = val_masked * (1 - val_masks) + fake_val * val_masks
# Unpainted
unpainted = val_truth * (1 - val_masks)
# Get first 4 samples
idx = slice(0, 4)
sample_masked = val_masked[idx]
sample_fake = fake_val[idx]
sample_inpainted = inpainted[idx]
sample_truth = val_truth[idx]
images = [sample_masked, sample_fake, sample_inpainted, sample_truth]
labels = ["Masked Input", "Generated Output", "Inpainted", "Ground Truth"]
# Plot
fig, axs = plt.subplots(len(images), 4, figsize=(16, 10))
for row in range(len(images)):
for col in range(4):
img = images[row][col].cpu()
img = img.permute(1, 2, 0) # CxHxW → HxWxC
axs[row, col].imshow(img.numpy(), cmap='gray' if img.shape[2] == 1 else None)
axs[row, col].axis('off')
if col == 0:
axs[row, col].set_title(labels[row], fontsize=10, loc='left')
plt.tight_layout()
plt.show()
g.train()
if verbose:
print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{total_steps}], '
f'Discriminator Loss: {(d_avg_loss_per_batch/total_steps):.4f}, '
f'Generator Loss: {(g_avg_loss_per_batch/total_steps):.4f}')
return g_losses, d_losses
We first do a basic training of our model. We will observe how the generator fills in the masked hole after each epoch and compare it to the ground truth.
g = Generator(im_channels=3).to(device)
d = Discriminator().to(device)
optimizer_g = optim.Adam(g.parameters(), lr=0.0001)
optimizer_d = optim.Adam(d.parameters(), lr=0.0001)
criterion = nn.BCELoss()
num_epochs = 10
g_losses, d_losses = train(g, d, train_loader, val_loader, criterion, optimizer_g, optimizer_d, epochs=num_epochs, verbose=True, show_images=True)
We find that the model's inpainting quality improves greatly with training. Reconstructing background elements is difficult at first, frequently resulting in mismatched colors and shading, but these improve significantly in later epochs. Simpler cat body parts like the belly and legs are generated fairly well, but the model consistently fails to produce more complex facial features like the eyes, nose, and mouth, though it can occasionally produce eyes. One significant drawback is its inability to handle human features, like hands, which remains a known problem in generative modeling. All things considered, the progression clearly demonstrates the model's growing grasp of texture and structure over time.
# Plot the loss graph
plt.figure(figsize=(10, 5))
plt.plot(g_losses, label="G")
plt.plot(d_losses, label="D")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.title("Training Losses")
plt.legend()
plt.grid(True)
plt.show()
This graph displays the discriminator's and generator's loss values throughout training. On average, the discriminator loss shows a distinct downward trend: the lower boundary of its oscillating curve gradually drops, forming an envelope that narrows over time. The generator's baseline, on the other hand, drifts downward only slightly, and its loss stays largely constant. As training goes on, both loss curves exhibit rising volatility, with larger value fluctuations. The source of this instability is unknown, but it does not seem to harm the visual quality of the generated inpainted images, suggesting that the average loss behavior is still trending in a productive direction.
torch.save(g.state_dict(), "initial_generator.pt")
torch.save(d.state_dict(), "initial_discriminator.pt")
Generative Adversarial Networks (GANs) rest on a rich mathematical framework that captures the dynamic interaction between a generator and a discriminator. The mathematical viewpoint presented in this blog, which offers a formal and intuitive understanding of the adversarial training setup, served as our source of inspiration. The min-max optimization problem at the core of this framework is formalized by the objective function:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$
This formulation supports our intuition: the generator aims to learn the true data distribution by producing outputs that the discriminator cannot distinguish from real data. To support this goal, we explored several loss functions beyond the adversarial objective—specifically L1, L2, and Huber loss. L1 loss (mean absolute error) is more robust to outliers and encourages sparsity, while L2 loss (mean squared error) penalizes larger errors more heavily, leading to smoother predictions. Huber loss offers a balance between the two, behaving like L2 for small errors and L1 for large ones. By experimenting with these loss functions, we aimed to better guide the generator during training and enhance the quality of inpainted outputs.
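The contrast between these reconstruction losses shows up clearly on a single large-error example. The sketch below uses PyTorch defaults (HuberLoss with delta=1.0); the tensors are toy values, not our data:

```python
import torch
import torch.nn as nn

pred   = torch.tensor([0.0, 0.0, 4.0])   # one outlier prediction (error of 4)
target = torch.tensor([0.0, 0.5, 0.0])

l1 = nn.L1Loss()(pred, target)       # mean(|e|)   = (0 + 0.5 + 4)  / 3
l2 = nn.MSELoss()(pred, target)      # mean(e^2)   = (0 + 0.25 + 16) / 3
hu = nn.HuberLoss()(pred, target)    # 0.5*e^2 for |e| <= 1, |e| - 0.5 beyond

print(l1.item(), l2.item(), hu.item())
```

The outlier dominates the L2 value (its squared error is 16) while L1 and Huber grow only linearly with it, which is the robustness trade-off described above.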
def get_trained_model_or_train(experiment_name, experiment_params):
model_dir = f"models/{experiment_name}"
if os.path.exists(model_dir):
print(f"Found Experiment {experiment_name}")
trained_generator = Generator(im_channels=3).to(device)
trained_discriminator = Discriminator().to(device)
trained_generator.load_state_dict(torch.load(f"{model_dir}/generator.pt"))
trained_discriminator.load_state_dict(torch.load(f"{model_dir}/discriminator.pt"))
# Load losses if they exist
g_losses_path = f"{model_dir}/g_losses.pt"
d_losses_path = f"{model_dir}/d_losses.pt"
g_losses = torch.load(g_losses_path) if os.path.exists(g_losses_path) else []
d_losses = torch.load(d_losses_path) if os.path.exists(d_losses_path) else []
return trained_generator, trained_discriminator, g_losses, d_losses
    else:
        generator = Generator(im_channels=3).to(device)
        discriminator = Discriminator().to(device)
        # Rebind the optimizers to the freshly created models. The optimizers in
        # experiment_params were built for other model instances, so stepping
        # them would never update these networks.
        params = dict(experiment_params)
        params["optimizer_g"] = optim.Adam(generator.parameters(), lr=0.0001)
        params["optimizer_d"] = optim.Adam(discriminator.parameters(), lr=0.0001)
        g_losses, d_losses = train(g=generator, d=discriminator, **params)
os.makedirs(model_dir, exist_ok=True) # Create directory if it doesn't exist
torch.save(generator.state_dict(), f"{model_dir}/generator.pt")
torch.save(discriminator.state_dict(), f"{model_dir}/discriminator.pt")
# Save losses
torch.save(g_losses, f"{model_dir}/g_losses.pt")
torch.save(d_losses, f"{model_dir}/d_losses.pt")
return generator, discriminator, g_losses, d_losses
optimizer_g = optim.Adam(g.parameters(), lr=0.0001)
optimizer_d = optim.Adam(d.parameters(), lr=0.0001)
criterion = nn.L1Loss()
num_epochs = 10
g, d, g_losses, d_losses = get_trained_model_or_train(
experiment_name="l1_loss",
experiment_params={
"train_loader": train_loader,
"val_loader": val_loader,
"criterion": criterion,
"optimizer_g": optimizer_g,
"optimizer_d": optimizer_d,
"epochs": num_epochs,
"verbose": True,
"show_images": False
}
)
MSELoss and L1Loss did not seem to work with our model in this flow. A likely cause: the optimizers above were created for the previously trained g and d, so they never update the fresh models that get_trained_model_or_train instantiates.
optimizer_g = optim.Adam(g.parameters(), lr=0.0001)
optimizer_d = optim.Adam(d.parameters(), lr=0.0001)
criterion = nn.MSELoss()
num_epochs = 10
g, d, g_losses, d_losses = get_trained_model_or_train(
experiment_name="l2_loss",
experiment_params={
"train_loader": train_loader,
"val_loader": val_loader,
"criterion": criterion,
"optimizer_g": optimizer_g,
"optimizer_d": optimizer_d,
"epochs": num_epochs,
"verbose": True,
"show_images": False
}
)
optimizer_g = optim.Adam(g.parameters(), lr=0.0001)
optimizer_d = optim.Adam(d.parameters(), lr=0.0001)
criterion = nn.HuberLoss()
num_epochs = 10
g, d, g_losses, d_losses = get_trained_model_or_train(
experiment_name="huber_loss",
experiment_params={
"train_loader": train_loader,
"val_loader": val_loader,
"criterion": criterion,
"optimizer_g": optimizer_g,
"optimizer_d": optimizer_d,
"epochs": num_epochs,
"verbose": True,
"show_images": False
}
)
# Load pre-trained models
g = Generator(im_channels=3).to(device)
d = Discriminator().to(device)
g.load_state_dict(torch.load("initial_generator.pt", map_location=device))
d.load_state_dict(torch.load("initial_discriminator.pt", map_location=device))
TODO: write a function that displays the inpainting results of each model trained with a different loss function, similar to the "validation" cell below. This only becomes useful once the alternative loss functions train correctly.
# Shows results after training
num_evaluation = 10
num_image_display = 5
val_iter = iter(val_loader)
# Loss (the loaded models were trained with BCE)
criterion = nn.BCELoss()
g_tot_loss = 0
d_tot_loss = 0
num_total = 0
# Validate GAN
f, ax = plt.subplots(num_image_display, 2, figsize=(12, 12))
ax[0, 0].set_title("Inpainted Images (Fake)")
ax[0, 1].set_title("Ground Truth Images")
for i in range(num_evaluation):
g.eval()
d.eval()
with torch.no_grad():
val_batch = next(val_iter)
val_masked, val_truth = val_batch
# Move validation data to GPU
val_masked = val_masked.to(device)
val_truth = val_truth.to(device)
N = len(val_masked)
# Get labels
true_labels = torch.ones(N).to(device)
fake_labels = torch.zeros(N).to(device)
val_masks = create_masks(N).to(device)
val_masked = torch.cat((val_masked, val_masks), dim=1)
# Get the inpainted image using the generator
fake_im = g(val_masked)
inpainted = fake_im * val_masks[0] + val_masked[:, :3, :, :]
outputs = d(inpainted).view(-1)
d_tot_loss += criterion(outputs, fake_labels).item()
g_tot_loss += criterion(outputs, true_labels).item()
num_total += 1
if (i < num_image_display):
# Display first sample
ax[i, 0].imshow(inpainted[0].cpu().permute(1, 2, 0))
ax[i, 1].imshow(val_truth[0].cpu().permute(1, 2, 0))
plt.show()
print(f'Number of Validation Images: {num_total * N}, '
      f'Average Discriminator Loss: {d_tot_loss / num_total:.4f}, '
      f'Average Generator Loss: {g_tot_loss / num_total:.4f}')
Our DCGAN model demonstrated promising results in the object removal and inpainting task. It was able to generate visually plausible completions of the missing regions, especially for relatively uniform backgrounds like fur. However, the model still has clear room for improvement. Recent advances such as Stable Diffusion and other Transformer-based architectures have shown significant gains in generative tasks. Incorporating these state-of-the-art models, though computationally expensive, could potentially lead to more coherent and higher-quality results.
At the start of our project, we simplified the problem by always masking a fixed square at the center of each image. This made the task more controlled, but less practical for real-world applications. A more flexible approach would involve user-defined masks of arbitrary shape and size. Supporting custom masks would make the model far more robust and widely usable across different tasks. Future work could involve incorporating spatial attention or segmentation cues to guide such inpainting.
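As a sketch of that direction, create_masks could be generalized to random rectangular holes. The create_random_masks function below is a hypothetical extension, not part of our trained pipeline; it keeps the same output convention as create_masks (1 marks the hole, shape N x 1 x H x W):

```python
import torch

def create_random_masks(num_masks, im_size=64, min_hole=10, max_hole=30, seed=None):
    """Random rectangular masks; 1 marks the hole, matching create_masks."""
    gen = torch.Generator()
    if seed is not None:
        gen.manual_seed(seed)
    masks = torch.zeros((num_masks, 1, im_size, im_size))
    for i in range(num_masks):
        # Sample a hole size, then a position that keeps it inside the image
        h = int(torch.randint(min_hole, max_hole + 1, (1,), generator=gen))
        w = int(torch.randint(min_hole, max_hole + 1, (1,), generator=gen))
        y = int(torch.randint(0, im_size - h + 1, (1,), generator=gen))
        x = int(torch.randint(0, im_size - w + 1, (1,), generator=gen))
        masks[i, 0, y : y + h, x : x + w] = 1
    return masks

masks = create_random_masks(4, seed=0)
print(masks.shape)
```

Training on such masks would also require passing the per-sample mask through the dataset and loss computation rather than relying on a single shared center mask.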
Training the model involved experimenting with several loss functions, including L1, L2, and Huber loss, which are commonly used for regression tasks. However, these losses did not yield stable or visually pleasing outputs in our setting. We eventually settled on binary cross-entropy (BCE) loss due to its alignment with the adversarial training objective and improved qualitative results. Still, finding the right balance between adversarial loss and reconstruction loss remains a challenge. Future exploration into hybrid loss strategies or perceptual loss may help improve output fidelity.
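One such hybrid strategy, which we describe but did not implement, weights a reconstruction term against the adversarial term. The sketch below is illustrative only: hybrid_generator_loss and lambda_rec are made-up names, and lambda_rec is an untuned hyperparameter:

```python
import torch
import torch.nn as nn

adv_criterion = nn.BCELoss()
rec_criterion = nn.L1Loss()
lambda_rec = 10.0  # hypothetical weight on the reconstruction term

def hybrid_generator_loss(d_out, inpainted, truth, mask):
    """Adversarial loss plus L1 reconstruction on the hole region.

    Pixels outside the hole are zeroed by the mask; the L1 mean still runs
    over the full image, so lambda_rec absorbs that scaling.
    """
    adv = adv_criterion(d_out, torch.ones_like(d_out))
    rec = rec_criterion(inpainted * mask, truth * mask)
    return adv + lambda_rec * rec

# Toy shapes matching our setup: batch of 2, 3x64x64 images, 20x20 center hole.
d_out = torch.tensor([0.4, 0.6])
inpainted = torch.rand(2, 3, 64, 64)
truth = torch.rand(2, 3, 64, 64)
mask = torch.zeros(1, 64, 64)
mask[:, 22:42, 22:42] = 1
loss = hybrid_generator_loss(d_out, inpainted, truth, mask)
print(loss.item())
```

Dropping this in place of `g_err` in our train loop would be the natural experiment, with lambda_rec swept over a few orders of magnitude.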
One significant challenge we encountered was the inconsistency in our dataset of cat images. Many images were not standardized—cats varied in position, lighting, and scale, and some images even included human hands or other distractions. These inconsistencies made it harder for the model to learn a clean background distribution. Moving forward, it may be beneficial to either cleanse the current dataset or switch to a more standardized and curated one. Starting with an easier dataset may also allow the model to learn more effectively before tackling more complex scenarios.
Testing that we can train with different loss functions.

We moved training with a different loss outside of the get_trained_model_or_train function and confirmed that training with L2/MSE loss works on its own, so the problem lies in get_trained_model_or_train: the fresh models it creates are never connected to the optimizers passed in.
g = Generator(im_channels=3).to(device)
d = Discriminator().to(device)
optimizer_g = optim.Adam(g.parameters(), lr=0.0001)
optimizer_d = optim.Adam(d.parameters(), lr=0.0001)
criterion = nn.MSELoss(reduction="mean")
num_epochs = 10
g_losses, d_losses = train(g, d, train_loader, val_loader, criterion, optimizer_g, optimizer_d, epochs=num_epochs, verbose=True, show_images=True)
# Plot the loss graph
plt.figure(figsize=(10, 5))
plt.plot(g_losses, label="G")
plt.plot(d_losses, label="D")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.title("Training Losses")
plt.legend()
plt.grid(True)
plt.show()